import pandas as p
from BuildStataInputData import load_vehicle_data, load_household_data
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


# For each NHTS survey year: load the cleaned household-level file, derive
# per-driver / per-vehicle travel, print household-weighted averages of the
# selected variables, and save bar charts of them split by safety-inspection
# policy status and household vehicle count.

SURVEY_YEARS = [2017, 2009]

# Golden-ratio figure: FIG_WIDTH inches wide, FIG_WIDTH / phi inches tall.
FIG_WIDTH = 4
FIG_HEIGHT = FIG_WIDTH / ((1 + np.sqrt(5)) / 2)

# y position (data units) of the in-plot "Single-/Multi-Vehicle" bar labels.
# NOTE(review): fixed value suited to the Travel/Vehicle scale; other
# variables will likely need a different height.
BAR_LABEL_Y = 2000

# Title template ({year} is filled in per survey) and output-file stem for
# every variable that can be plotted. Both are keyed by variable name so that
# narrowing VARS_TO_PLOT can never pair a variable with another one's title.
TITLE_TEMPLATES = {
    'Travel/Driver': 'NHTS {year}\nTravel (VMT)/Driver vs. Policy Status\nby Household Vehicle Count',
    'Travel/Vehicle': 'NHTS {year}\nTravel (VMT)/Vehicle vs. Policy Status\nby Household Vehicle Count',
    'num_vehicles': '',
    'sum_total_miles': '',
}
SAVE_FILE_STEMS = {
    'Travel/Driver': 'travel_per_driver',
    'Travel/Vehicle': 'travel_per_vehicle',
    'num_vehicles': 'vehicle_count',
    'sum_total_miles': 'sum_total_miles',
}

# Variables summarised and plotted in this run (any subset of the keys above).
VARS_TO_PLOT = ['Travel/Vehicle']


def _add_derived_columns(joint_data):
    """Add grouping labels and per-driver / per-vehicle travel columns in place.

    'Household Class' is 'Multi-Vehicle' when num_vehicles > 1, otherwise
    'Single-Vehicle' (zero-vehicle households also get this label, but they
    are excluded from the plots). 'Policy Status' is 'Not Subject to...' where
    HasSafety equals False, otherwise 'Subject to...'.

    A zero denominator would make the travel ratios +/-inf, which pandas'
    isna() does not flag; those are turned into NaN so they drop out of the
    averages and bar heights instead of making them infinite.
    """
    joint_data['Household Class'] = np.where(
        joint_data['num_vehicles'] > 1, 'Multi-Vehicle', 'Single-Vehicle')
    joint_data['Policy Status'] = np.where(
        joint_data['HasSafety'].eq(False),
        'Not Subject to\nSafety Inspections',
        'Subject to\nSafety Inspections')
    miles = joint_data['sum_total_miles']
    joint_data['Travel/Driver'] = (
        miles / joint_data['DRVRCNT']).replace([np.inf, -np.inf], np.nan)
    joint_data['Travel/Vehicle'] = (
        miles / joint_data['num_vehicles']).replace([np.inf, -np.inf], np.nan)


def _weighted_mean(values, weights):
    """Return the weighted mean of `values`, ignoring NaN and +/-inf entries.

    Returns NaN when no finite value remains (np.average would raise).
    """
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    keep = np.isfinite(values)
    if not keep.any():
        return np.nan
    return np.average(values[keep], weights=weights[keep])


def _plot_policy_bars(plot_data, var, year):
    """Draw, save to Plots/ and show a WTHHFIN-weighted bar chart of `var`.

    Bars are grouped by 'Policy Status' on the x axis and split by
    'Household Class' (single- vs multi-vehicle) within each group.
    """
    fig, ax = plt.subplots(1, 1, figsize=(FIG_WIDTH, FIG_HEIGHT))
    # Sorting places 'Not Subject to...' at x=0 and 'Subject to...' at x=1.
    sns.barplot(plot_data.sort_values('Policy Status'), y=var, hue='Household Class', x='Policy Status',
                weights='WTHHFIN', hue_order=['Single-Vehicle', 'Multi-Vehicle'],
                palette=['xkcd:red', 'xkcd:blue'], alpha=0.5, legend=False, ax=ax)
    # With two hue levels the dodged bars sit about 0.2 either side of each
    # x-category centre; label them in place of a legend.
    for center in (0, 1):
        ax.text(center - .2, BAR_LABEL_Y, 'Single-\nVehicle', ha='center', va='center')
        ax.text(center + .2, BAR_LABEL_Y, 'Multi-\nVehicle', ha='center', va='center')
    ax.set_title(TITLE_TEMPLATES[var].format(year=year))
    fig.savefig(f'Plots/{SAVE_FILE_STEMS[var]}_{year}.png', bbox_inches='tight')
    plt.show()
    # Release the figure; otherwise every iteration keeps one alive.
    plt.close(fig)


for year in SURVEY_YEARS:
    joint_data = p.read_csv('CleanData/FullNHTS{}.csv'.format(year))
    _add_derived_columns(joint_data)

    # Averages are taken over all households; the plots cover only households
    # that own vehicles. The filter is computed once so every variable
    # uses the same population. Rows with a missing vehicle count are kept,
    # as before.
    zero_veh_indices = joint_data['num_vehicles'] <= 0
    plot_data = joint_data.loc[~zero_veh_indices, :]

    for temp_var in VARS_TO_PLOT:
        temp_mean = _weighted_mean(joint_data[temp_var], joint_data['WTHHFIN'])
        print(f'NHTS {year} {temp_var} average: {temp_mean}')
        if 'all_vehicle_num_vehicles' in temp_var:
            # Summarised only; there is no bar chart for this variable.
            continue
        _plot_policy_bars(plot_data, temp_var, year)